/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "DTLSProtocol.h"
#include <iomanip>

namespace aiengine {

// Decides whether a packet looks like DTLS: it must hold at least a record
// header and its record content type must lie in the
// [DTLS_CT_HANDSHAKE, DTLS_CT_APPLICATION_DATA] range.
// Updates the valid/invalid packet counters accordingly.
bool DTLSProtocol::check(const Packet &packet) {

	const int length = packet.getLength();
	bool is_dtls = false;

	if (length >= header_size) {
		setHeader(packet.getPayload());

		const auto content_type = header_->type;
		is_dtls = (content_type >= DTLS_CT_HANDSHAKE)and(content_type <= DTLS_CT_APPLICATION_DATA);
	}

	if (is_dtls)
		++total_valid_packets_;
	else
		++total_invalid_packets_;

	return is_dtls;
}

// Memory currently in use: the protocol object itself plus whatever the
// DTLSInfo cache reports as in use.
uint64_t DTLSProtocol::getCurrentUseMemory() const {

	const uint64_t self_bytes = sizeof(DTLSProtocol);

	return self_bytes + info_cache_->getCurrentUseMemory();
}

// Memory allocated: the protocol object itself plus whatever the DTLSInfo
// cache reports as allocated.
uint64_t DTLSProtocol::getAllocatedMemory() const {

	const uint64_t self_bytes = sizeof(DTLSProtocol);

	return self_bytes + info_cache_->getAllocatedMemory();
}

uint64_t DTLSProtocol::getTotalAllocatedMemory() const {

	return getAllocatedMemory();
}

void DTLSProtocol::setDynamicAllocatedMemory(bool value) {

        info_cache_->setDynamicAllocatedMemory(value);
}

bool DTLSProtocol::isDynamicAllocatedMemory() const {

        return info_cache_->isDynamicAllocatedMemory();
}

// Detaches the DTLSInfo object from every flow in the flow manager's table,
// returns each one to the info cache and logs how many flows/bytes were freed.
// Does nothing if the flow manager is no longer alive.
void DTLSProtocol::releaseCache() {

	if (FlowManagerPtr fm = flow_mng_.lock(); fm) {
		auto ft = fm->getFlowTable();

		std::ostringstream msg;
		msg << "Releasing " << name() << " cache";

		infoMessage(msg.str());

		uint64_t total_bytes_released_by_flows = 0;
		uint32_t release_flows = 0;

		for (auto &flow: ft) {
			if (SharedPointer<DTLSInfo> info = flow->getDTLSInfo(); info) {
				// Account for the DTLSInfo object being released, not the
				// size of the smart-pointer handle that refers to it.
				total_bytes_released_by_flows += sizeof(DTLSInfo);

				flow->layer7info.reset();
				++release_flows;
				// 'info' still holds a reference, so the object is alive here
				info_cache_->release(info);
			}
		}
		std::string funit = "Bytes";

		data_time_ = boost::posix_time::microsec_clock::local_time();

		// Scales the byte count and updates funit to the matching unit name
		unitConverter(total_bytes_released_by_flows, funit);

		msg.str("");
		msg << "Release " << release_flows << " flows";
		msg << ", flow " << total_bytes_released_by_flows << " " << funit;
		infoMessage(msg.str());
	}
}

// Hands the flow's DTLSInfo (if it has one) back to the info cache.
void DTLSProtocol::releaseFlowInfo(Flow *flow) {

	SharedPointer<DTLSInfo> info = flow->getDTLSInfo();

	if (!info)
		return;

	info_cache_->release(info);
}

// Inspects the DTLS records carried in the flow's current packet.
// Attaches a DTLSInfo from the cache to the flow on first use (logging and
// returning if the cache is exhausted). It then walks up to
// max_records_per_packet consecutive records and updates the per-content-type
// and per-handshake-message counters. Version/epoch are taken from the first
// record; a ChangeCipherSpec marks the flow as encrypted.
void DTLSProtocol::processFlow(Flow *flow) {

	CPUCycle cycles(&total_cpu_cycles_);
	int length = flow->packet->getLength();
	total_bytes_ += length;
	++total_packets_;
	++flow->total_packets_l7;

	current_flow_ = flow;

	if (length >= header_size) {
		SharedPointer<DTLSInfo> info = flow->getDTLSInfo();
		if (!info) {
			if (info = info_cache_->acquire(); !info) {
				logFailCache(info_cache_->name(), flow);
				return;
			}
			flow->layer7info = info;
		}

		setHeader(flow->packet->getPayload());

		int record_length = ntohs(header_->length);

		if (record_length > 0) {
			// ContentType alert(21) as defined by RFC 5246 and reused by RFC 6347 (DTLS)
			constexpr short dtls_ct_alert = 21;
			// Upper bound of records decoded per packet, prevents runaway
			// decoding of bogus record lengths
			constexpr int max_records_per_packet = 5;

			const dtls_header *record = header_;
			const uint8_t *payload = flow->packet->getPayload();
			int offset = 0;         // Byte offset of the current record in the payload
			int records = 0;        // Records decoded so far in this packet

			info->setVersion(ntohs(record->version));
			info->setEpoch(ntohs(record->epoch));

			do {
				short type = record->type;
				record_length = ntohs(record->length);
				// record->data[0] is only inside the packet when at least one
				// byte follows the record header; the first record is only
				// guaranteed to have a full header, so guard the read.
				bool has_body = offset + (int)sizeof(dtls_header) < length;
				short htype = has_body ? short(record->data[0]) : -1;
				++records;
#ifdef DEBUG
				std::cout << __FILE__ << ":" << __func__ << ":len:" << length << " rlen:" << record_length;
				std::cout << " type: " << int(type) << " offset:" << offset;
				std::cout << " htype:" << htype << std::endl;
#endif
				if (type == DTLS_CT_HANDSHAKE) {
					if (htype == DTLS_MT_CLIENT_HELLO)
						++total_client_hellos_;
					else if (htype == DTLS_MT_HELLO_VERIFY)
						++total_hello_verifies_requests_;
					else if (htype == DTLS_MT_SERVER_HELLO)
						++total_server_hellos_;
					else if (htype == DTLS_MT_CERTIFICATE)
						++total_certificates_;
					else if (htype == DTLS_MT_SERVER_KEY_EXCHANGE)
						++total_server_key_exchanges_;
					else if (htype == DTLS_MT_NEW_SESSION_TICKET)
						++total_new_session_tickets_;
					else if (htype == DTLS_MT_CERTIFICATE_REQUEST)
						++total_certificate_requests_;
					else if (htype == DTLS_MT_SERVER_DONE)
						++total_server_dones_;
					else if (htype == DTLS_MT_CERTIFICATE_VERIFY)
						++total_certificate_verifies_;
					else if (htype == DTLS_MT_CLIENT_KEY_EXCHANGE)
						++total_client_key_exchanges_;

					// The handshake could be encrypted
					if (info->isEncrypted() == false)
						++total_handshakes_;
					else
						++total_encrypted_handshakes_;

				} else if (type == DTLS_CT_CHANGE_CIPHER_SPEC) {
					++total_change_cipher_specs_;
					info->setEncrypted(true); // From this point all should be encrypted
				} else if (type == dtls_ct_alert) {
					++total_alerts_;
				} else if (type == DTLS_CT_APPLICATION_DATA) { // On Tls1.3 encrypted data can be sent
					++total_data_;
					info->incDataPdus();
				}

				++total_records_;

				offset += record_length + sizeof(dtls_header);

				if (records == max_records_per_packet) break; // Maximum Pdus per packet allowed
				record = reinterpret_cast<const dtls_header*>(&payload[offset]);
			} while (offset + (int)sizeof(dtls_header) < length);
		}
	} else {
		//if (flow->getPacketAnomaly() == PacketAnomalyType::NONE)
		//        flow->setPacketAnomaly(PacketAnomalyType::DTLS_BOGUS_HEADER);

		//anomaly_->incAnomaly(PacketAnomalyType::DTLS_BOGUS_HEADER);
	}
}

// Writes human-readable statistics to 'out'.
// level > 3 adds the per-type counters and the info cache statistics;
// level > 5 also dumps the flow forwarder statistics (when it is still alive).
// 'limit' is not used by this protocol.
void DTLSProtocol::statistics(std::basic_ostream<char>& out, int level, int32_t limit) const {

	showStatisticsHeader(out, level);

	const bool detailed = level > 3;

	if (detailed) {
		out << "\t" << "Total handshakes:       " << std::setw(10) << total_handshakes_ << "\n"
			<< "\t" << "Total encrypt handshakes:" << std::setw(9) << total_encrypted_handshakes_ << "\n"
			<< "\t" << "Total alerts:           " << std::setw(10) << total_alerts_ << "\n"
			<< "\t" << "Total change cipher specs:" << std::setw(8) << total_change_cipher_specs_ << "\n"
			<< "\t" << "Total data:             " << std::setw(10) << total_data_ << "\n";

		out << "\t" << "Total client hellos:    " << std::setw(10) << total_client_hellos_ << "\n"
			<< "\t" << "Total hello verifies:   " << std::setw(10) << total_hello_verifies_requests_ << "\n"
			<< "\t" << "Total server hellos:    " << std::setw(10) << total_server_hellos_ << "\n"
			<< "\t" << "Total certificates:     " << std::setw(10) << total_certificates_ << "\n"
			<< "\t" << "Total server key exs:   " << std::setw(10) << total_server_key_exchanges_ << "\n"
			<< "\t" << "Total certificate reqs: " << std::setw(10) << total_certificate_requests_ << "\n"
			<< "\t" << "Total server dones:     " << std::setw(10) << total_server_dones_ << "\n"
			<< "\t" << "Total certificates vers:" << std::setw(10) << total_certificate_verifies_ << "\n"
			<< "\t" << "Total client key exs:   " << std::setw(10) << total_client_key_exchanges_ << "\n"
			<< "\t" << "Total new session tickets:" << std::setw(8) << total_new_session_tickets_ << "\n"
			<< "\t" << "Total handshakes finish:" << std::setw(10) << total_handshake_finishes_ << "\n"
			<< "\t" << "Total records:          " << std::setw(10) << total_records_ << std::endl;
	}

	if (level > 5) {
		if (auto forwarder = flow_forwarder_.lock(); forwarder)
			forwarder->statistics(out);
	}

	if (detailed)
		info_cache_->statistics(out);
}

// Fills 'out' with the statistics header and, for level > 3, the record
// counters plus a nested "types" object with the handshake message counters.
void DTLSProtocol::statistics(Json &out, int level) const {

	showStatisticsHeader(out, level);

	if (level <= 3)
		return;

	out["handshakes"] = total_handshakes_;
	out["encrypt handshakes"] = total_encrypted_handshakes_;
	out["alerts"] = total_alerts_;
	out["change cipher specs"] = total_change_cipher_specs_;
	out["data"] = total_data_;

	Json types;

	types["client hellos"] = total_client_hellos_;
	types["hello verifies"] = total_hello_verifies_requests_;
	types["server hellos"] = total_server_hellos_;
	types["certificates"] = total_certificates_;
	types["server keys"] = total_server_key_exchanges_;
	types["certificate requests"] = total_certificate_requests_;
	types["server dones"] = total_server_dones_;
	types["certificate verifies"] = total_certificate_verifies_;
	types["client keys"] = total_client_key_exchanges_;
	types["handshake finish"] = total_handshake_finishes_;
	types["new session tickets"] = total_new_session_tickets_;
	types["records"] = total_records_;

	out["types"] = types;
}

// Exports every DTLS counter as a key/value CounterMap.
CounterMap DTLSProtocol::getCounters() const {

	CounterMap counters;

	// Generic traffic counters
	counters.addKeyValue("packets", total_packets_);
	counters.addKeyValue("bytes", total_bytes_);

	// Record content types
	counters.addKeyValue("handshakes", total_handshakes_);
	counters.addKeyValue("encrypt handshakes", total_encrypted_handshakes_);
	counters.addKeyValue("alerts", total_alerts_);
	counters.addKeyValue("change cipher specs", total_change_cipher_specs_);
	counters.addKeyValue("datas", total_data_);

	// Handshake message types
	counters.addKeyValue("client hellos", total_client_hellos_);
	counters.addKeyValue("hello verifies", total_hello_verifies_requests_);
	counters.addKeyValue("server hellos", total_server_hellos_);
	counters.addKeyValue("certificates", total_certificates_);
	counters.addKeyValue("server key exchanges", total_server_key_exchanges_);
	counters.addKeyValue("certificate requests", total_certificate_requests_);
	counters.addKeyValue("server dones", total_server_dones_);
	counters.addKeyValue("certificate verifies", total_certificate_verifies_);
	counters.addKeyValue("client key exchanges", total_client_key_exchanges_);
	counters.addKeyValue("handshake dones", total_handshake_finishes_);
	counters.addKeyValue("new session tickets", total_new_session_tickets_);
	counters.addKeyValue("records", total_records_);

	return counters;
}

// Resets the generic protocol counters (via reset()) and every DTLS counter.
void DTLSProtocol::resetCounters() {

	reset();

	// Record content types
	total_records_ = 0;
	total_handshakes_ = 0;
	total_encrypted_handshakes_ = 0;
	total_change_cipher_specs_ = 0;
	total_alerts_ = 0;
	total_data_ = 0;

	// Handshake message types
	total_client_hellos_ = 0;
	total_hello_verifies_requests_ = 0;
	total_server_hellos_ = 0;
	total_certificates_ = 0;
	total_server_key_exchanges_ = 0;
	total_certificate_requests_ = 0;
	total_server_dones_ = 0;
	total_certificate_verifies_ = 0;
	total_client_key_exchanges_ = 0;
	total_new_session_tickets_ = 0;
	total_handshake_finishes_ = 0;
}

void DTLSProtocol::increaseAllocatedMemory(int value) {

        info_cache_->create(value);
}

void DTLSProtocol::decreaseAllocatedMemory(int value) {

        info_cache_->destroy(value);
}

} // namespace aiengine
